Self-evolving vision transformer for chest X-ray diagnosis through knowledge distillation

Although deep learning-based computer-aided diagnosis systems have recently achieved expert-level performance, developing a robust model requires large, high-quality data with annotations that are expensive to obtain. This situation poses a conundrum in that annually collected chest X-rays cannot be utilized due to the absence of labels, especially in deprived areas. In this study, we present a framework named distillation for self-supervision and self-train learning (DISTL), inspired by the learning process of radiologists, which can improve the performance of a vision transformer simultaneously with self-supervision and self-training through knowledge distillation. In external validation from three hospitals for diagnosis of tuberculosis, pneumothorax, and COVID-19, DISTL offers gradually improved performance as the amount of unlabeled data increases, even better than the fully supervised model with the same amount of labeled data. We additionally show that the model obtained with DISTL is robust to various real-world nuisances, offering better applicability in clinical settings.

configuration can be utilized in the framework with a siamese design where one model learns from the prediction of the other model instead of labels, several lines of semi- and self-supervised learning work have utilized knowledge distillation, as mentioned above, in the application to self-training. Of note, several recent studies have suggested the possibility that a model can obtain performance similar to or better than the fully supervised model through semi- or self-supervised learning methods based on the knowledge distillation framework 3,6.

Details of DISTL algorithm

Pre-training for task-relevant findings. As transfer learning from relevant tasks can significantly improve performance, we first trained the model to classify common CXR findings using a large corpus from the CheXpert dataset 9, containing over 220,000 CXRs and corresponding labels. Among 10 common CXR findings, lung opacity, consolidation, edema, pneumonia, and pleural effusion were selected as the task-relevant classes based on the clinicians' opinion. With pre-training for task-relevant CXR features on massive CXR data, the model attains excellent generalization capability in external validation. In addition, this pre-training step gives the model a robust feature-extracting capacity grounded in prior knowledge of CXR. Without the pre-training, model performance decreased significantly, especially at early time points T (Supplementary Fig. S3).
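To make this pre-training step concrete, the following is a minimal sketch of multi-label classification on the five task-relevant findings. The `timm` model name, the `chexpert_loader`, and the head layout are illustrative assumptions, not the authors' implementation.

```python
# Sketch: pre-training a ViT backbone on task-relevant CXR findings (assumption:
# a ViT from the `timm` library and a hypothetical `chexpert_loader` that yields
# images with 5 binary labels for lung opacity, consolidation, edema, pneumonia,
# and pleural effusion).
import torch
import torch.nn as nn
import timm

TASK_RELEVANT = ["lung_opacity", "consolidation", "edema", "pneumonia", "pleural_effusion"]

backbone = timm.create_model("vit_small_patch16_224", pretrained=True, num_classes=0)
classifier = nn.Linear(backbone.num_features, len(TASK_RELEVANT))
model = nn.Sequential(backbone, classifier)

criterion = nn.BCEWithLogitsLoss()          # findings can co-occur, so multi-label BCE
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def pretrain_one_epoch(chexpert_loader):
    model.train()
    for images, labels in chexpert_loader:   # labels: (B, 5) float tensor in {0, 1}
        logits = model(images)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```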
Self-evolving training scheme. Details of our training scheme are illustrated in Fig. 2. First, with the small labeled data $D_l = \{(x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n)\}$, the initial model is built with supervised learning. Then, we use this model as the initial teacher and let the student learn from the teacher using the proposed DISTL method. In addition, to prevent performance deterioration of the student caused by wrong estimations, supervised correction with the initial small labeled data is performed every N steps. The updated models at the end of DISTL are utilized as the starting point of the next-generation model, similar to the previous self-training approach 6. Specifically, as the amount of available unlabeled data $D_u^T$ increases over time $T = 1, 2, \ldots$, the updated teacher and student models $g_{\theta_t}^T, g_{\theta_s}^T$ at the end of DISTL for the current $T$ are used as the starting points of the next teacher and student models $g_{\theta_t}^{T+1}, g_{\theta_s}^{T+1}$ for $T + 1$.
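A minimal sketch of this generation-wise schedule is given below. The helper functions (`train_supervised`, `distl_step`, `supervised_correction`) and the data containers are hypothetical placeholders for the steps described above, not the authors' code.

```python
# Sketch of the self-evolving schedule: supervised initialization, then DISTL on
# growing unlabeled pools D_u^T with periodic supervised correction every N steps,
# with the updated teacher/student carried over to the next generation.
import copy

def self_evolving_training(labeled_data, unlabeled_data_per_T, num_generations, correction_every_N):
    student = train_supervised(labeled_data)       # initial model from small labeled set D_l
    teacher = copy.deepcopy(student)               # initial teacher starts as a copy of the student

    for T in range(1, num_generations + 1):
        unlabeled = unlabeled_data_per_T[T]        # D_u^T grows as time T increases
        for step, batch in enumerate(unlabeled):
            distl_step(teacher, student, batch)    # self-supervision + self-training update
            if (step + 1) % correction_every_N == 0:
                supervised_correction(student, labeled_data)   # restore drift with labels
        # the updated models become the starting point for generation T + 1
    return teacher, student
```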
Loss functions for self-supervision and self-training. The overall framework of our method, shown in Fig. 1c, has a configuration similar to recent self-supervised (Fig. 1b) 3,10 and self-training approaches (Fig. 1a) 6, which also share the knowledge distillation scheme for teacher-student learning.
Specifically, both the teacher and student models share the same network architecture. The network architecture $g$ parameterized by $\theta$ is composed of a backbone $f$ (before the final linear classifier) and two heads $h_{\mathrm{cls}}, h_{\mathrm{ss}}$ for disease classification and self-supervision, respectively:

$$g_\theta(x) = \big(h_{\mathrm{cls}}(f(x)),\ h_{\mathrm{ss}}(f(x))\big). \quad (1)$$

Given an input image $x$, the models in (1) yield two predictions $P^{\mathrm{cls}}_\theta(x)$ and $P^{\mathrm{ss}}_\theta(x)$ with dimensions $K_{\mathrm{cls}}$ and $K_{\mathrm{ss}}$, respectively, by normalizing the network output with the softmax function with temperature parameters $\tau_{\mathrm{cls}}$ and $\tau_{\mathrm{ss}}$:

$$P^{\mathrm{cls}}_\theta(x)^{(i)} = \frac{\exp\!\big(h_{\mathrm{cls}}(f(x))^{(i)}/\tau_{\mathrm{cls}}\big)}{\sum_{k=1}^{K_{\mathrm{cls}}} \exp\!\big(h_{\mathrm{cls}}(f(x))^{(k)}/\tau_{\mathrm{cls}}\big)}, \quad (2)$$

$$P^{\mathrm{ss}}_\theta(x)^{(i)} = \frac{\exp\!\big(h_{\mathrm{ss}}(f(x))^{(i)}/\tau_{\mathrm{ss}}\big)}{\sum_{k=1}^{K_{\mathrm{ss}}} \exp\!\big(h_{\mathrm{ss}}(f(x))^{(k)}/\tau_{\mathrm{ss}}\big)}, \quad (3)$$

where $\tau$ controls the sharpness of the output distribution.
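As an illustration of Eqs. (1)-(3), the sketch below defines a shared backbone with the two heads and temperature-scaled softmax outputs. The head sizes and temperature values are illustrative assumptions.

```python
# Sketch of the two-head architecture: one shared backbone f with a disease
# classification head h_cls and a self-supervision head h_ss, each followed by
# a temperature-scaled softmax (head dimensions here are illustrative).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TeacherStudentNet(nn.Module):
    def __init__(self, backbone, feat_dim, k_cls=2, k_ss=4096):
        super().__init__()
        self.f = backbone                          # shared feature extractor
        self.h_cls = nn.Linear(feat_dim, k_cls)    # disease-classification head
        self.h_ss = nn.Linear(feat_dim, k_ss)      # self-supervision head

    def forward(self, x, tau_cls=0.1, tau_ss=0.1):
        feat = self.f(x)
        p_cls = F.softmax(self.h_cls(feat) / tau_cls, dim=-1)   # P_cls with temperature tau_cls
        p_ss = F.softmax(self.h_ss(feat) / tau_ss, dim=-1)      # P_ss with temperature tau_ss
        return p_cls, p_ss
```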
Then, we train a student model parameterized by $\theta_s$ to match the prediction of the teacher model parameterized by $\theta_t$. Specifically, for a given input image $x$, a set $V = \{x_o, x_g, x_{l_1}, \ldots, x_{l_L}\}$ is constructed, containing one original view $x_o$ without augmentation, one global view $x_g$, and $L$ local views $x_l$ of smaller size, where the global crop and multiple local crops are obtained with the multi-crop strategy 11 and random augmentations to construct differently distorted views. We used the set $V$ in two ways. First, the clean original view is passed to the teacher, while the global view with weak augmentation and noise is passed to the student model. The student is then trained to mimic the pseudo-label generated by the teacher by minimizing the cross-entropy:

$$\mathcal{L}_{\mathrm{cls}} = H\big(P^{\mathrm{cls}}_{\theta_t}(x_o),\ P^{\mathrm{cls}}_{\theta_s}(x_g)\big), \quad (4)$$

where $H(a, b) = -a \log b$. Second, the original and global views are passed through the teacher while all views are passed through the student, thereby encouraging global-local correspondence with the following optimization problem:

$$\mathcal{L}_{\mathrm{ss}} = \sum_{x \in \{x_o,\, x_g\}} \ \sum_{\substack{x' \in V,\ x' \neq x}} H\big(P^{\mathrm{ss}}_{\theta_t}(x),\ P^{\mathrm{ss}}_{\theta_s}(x')\big). \quad (5)$$

We found that optimizing this self-supervised term encourages the model to learn task-agnostic semantic features of the CXR (Supplementary Fig. S2), which implies that the model better attends to the shape of the CXR as a human reader does.
Combining Eq. (4) and Eq. (5), the final optimization problem can be defined as the weighted combination:

$$\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{cls}} + \alpha\, \mathcal{L}_{\mathrm{ss}}, \quad (6)$$

where $\alpha$ is a hyperparameter that adjusts the weight between the classification loss and the self-supervising loss.
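The sketch below illustrates how Eqs. (4)-(6) could be computed, assuming `teacher` and `student` follow the two-head interface of the earlier sketch and the multi-crop views are already prepared; the exact view pairing in the self-supervised term is a simplified reading of the text, not the authors' exact implementation.

```python
# Sketch of the DISTL loss: pseudo-label cross-entropy (Eq. 4), global-local
# correspondence on the self-supervision head (Eq. 5), and their weighted sum (Eq. 6).
import torch

def cross_entropy(p_teacher, p_student, eps=1e-7):
    # H(a, b) = -sum_i a_i log b_i, averaged over the batch
    return -(p_teacher * torch.log(p_student + eps)).sum(dim=-1).mean()

def distl_loss(teacher, student, x_orig, x_global, x_locals, alpha=0.5):
    with torch.no_grad():                          # teacher only provides targets
        t_cls_o, t_ss_o = teacher(x_orig)          # clean original view
        _,       t_ss_g = teacher(x_global)        # global view

    # Eq. (4): student on the noised global view mimics the teacher's pseudo-label
    s_cls_g, s_ss_g = student(x_global)
    loss_cls = cross_entropy(t_cls_o, s_cls_g)

    # Eq. (5): teacher outputs on {original, global} paired with student outputs on all views
    teacher_ss = {"orig": t_ss_o, "global": t_ss_g}
    student_ss = {"orig": student(x_orig)[1], "global": s_ss_g}
    student_ss.update({f"local{i}": student(x_l)[1] for i, x_l in enumerate(x_locals)})

    loss_ss, n_pairs = 0.0, 0
    for t_name, t_p in teacher_ss.items():
        for s_name, s_p in student_ss.items():
            if t_name == s_name:                   # skip identical view pairs
                continue
            loss_ss = loss_ss + cross_entropy(t_p, s_p)
            n_pairs += 1
    loss_ss = loss_ss / n_pairs

    # Eq. (6): weighted combination of the two terms
    return loss_cls + alpha * loss_ss
```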
Unlike noisy self-training, where the teacher model remains unchanged during training, we build a momentum teacher using an exponential moving average (EMA) of the student weights,

$$\theta_t \leftarrow \lambda\, \theta_t + (1 - \lambda)\, \theta_s,$$

where $\lambda$ follows a cosine schedule. This lets the student updates slowly propagate into the teacher weights with momentum, enabling the teacher to improve its performance gradually in accordance with the student, while averting the performance deterioration that would arise if the student were misguided by wrong predictions.
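The EMA update itself is compact; below is a sketch assuming the cosine schedule runs between two illustrative endpoints (the 0.996 base momentum is an assumption, not a value reported here).

```python
# Sketch of the momentum-teacher update: theta_t <- lam * theta_t + (1 - lam) * theta_s,
# with lam rising along a cosine schedule from lam_base toward lam_final.
import math
import torch

@torch.no_grad()
def update_teacher(teacher, student, step, total_steps, lam_base=0.996, lam_final=1.0):
    # cosine schedule: lam = lam_base at step 0, lam = lam_final at the last step
    lam = lam_final - (lam_final - lam_base) * (math.cos(math.pi * step / total_steps) + 1) / 2
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.data.mul_(lam).add_(p_s.data, alpha=1 - lam)
```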
The two loss functions in (6) deserve further discussion. By minimizing the first term $\mathcal{L}_{\mathrm{cls}}$, the noised student is trained to be consistent with the pseudo-label generated by the clean teacher. Adding noise brings the important benefit of forcing invariance into the decision function, as it enforces prediction consistency across variously augmented versions of a given image. In addition, it also adds robustness to common corruptions and perturbations 6.
The minimization of the second term, $\mathcal{L}_{\mathrm{ss}}$, enforces the global-local correspondence so that the model can learn self-supervised features explicitly containing object boundaries and semantic information 3. These task-agnostic self-supervised features provide a useful human-like shape bias that helps avoid overfitting to texture and other non-informative characteristics of the image, resulting in improved generalization performance and model stability.

Ablation Study
We performed ablation studies to clarify the role of each component in the proposed DISTL framework (Supplementary Fig. S3).
Pre-training on task-relevant CXR features. Learning general but task-relevant CXR features through pre-training on a large data corpus is one of the key components of our method. As shown in Supplementary Fig. S3, when the pre-trained weights were not used as the initialization point, the performance was suboptimal, although the gradual performance increase with the proposed framework under an increasing amount of unlabeled data was maintained.
Role of two loss terms. As the loss function of the proposed method consists of two terms, we ablated each term and evaluated its effect. As shown in Supplementary Fig. S3, when the self-supervising loss term was not used, the performance did not improve over increasing $T$ and even decreased at later $T$, negating the merit of our framework. Similarly, when the model was trained only with the self-supervising loss term and fine-tuned with the labeled data at the correction step, its performance was overall lower than that of the model trained with both terms. Taken together, these results suggest that both terms are necessary for the proposed framework to achieve stably improving performance with increasing unlabeled data.
Correction step with label. The correction step within the proposed framework serves to restore weights erroneously updated by wrong estimations of the teacher, using the initial set of small labeled data. To verify its contribution to the performance, we ablated the correction step (see Supplementary Fig. S3). While performance improvement over the baseline was still observed, the improvement over increasing $T$ was lower than that with the correction step.